Backpropagation Through Time (BPTT)

"Just a fancy name for standard backpropagation on an unrolled RNN."

Formula

Suppose the network is defined as: $$ \begin{aligned} s_t &= \tanh(Ux_t + Ws_{t-1}) \\ \hat{y}_t &= \text{softmax}(Vs_t) \end{aligned} $$ The error is given by: $$ \begin{aligned} \mathcal{L}(y, \hat{y}) &= \sum_{t} \mathcal{L}_t(y_t, \hat{y}_t) = -\sum_{t} y_t \log \hat{y}_t \end{aligned} $$ We treat the full sequence (e.g. a sentence) as one training example, so the total error is just the sum of the errors at each time step (e.g. a word).
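
Below is a minimal NumPy sketch of this forward pass and summed loss. The shapes (one-hot inputs and targets, hidden size given by `W`) and the convention $s_{-1} = 0$ are assumptions for illustration, not taken from the text.

```python
import numpy as np

def forward(xs, U, W, V):
    """xs: list of input vectors; returns hidden states s_t and outputs y_hat_t."""
    s_prev = np.zeros(W.shape[0])                # assume s_{-1} = 0
    states, outputs = [], []
    for x_t in xs:
        s_t = np.tanh(U @ x_t + W @ s_prev)      # s_t = tanh(U x_t + W s_{t-1})
        z_t = V @ s_t                            # logits z_t = V s_t
        y_hat_t = np.exp(z_t - z_t.max())
        y_hat_t /= y_hat_t.sum()                 # softmax(V s_t)
        states.append(s_t)
        outputs.append(y_hat_t)
        s_prev = s_t
    return states, outputs

def total_loss(ys, outputs):
    """Cross-entropy summed over all time steps; ys are one-hot target vectors."""
    return -sum(float(y_t @ np.log(y_hat_t)) for y_t, y_hat_t in zip(ys, outputs))
```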

The gradient: $$ \begin{aligned} \frac{\partial \mathcal{L}}{\partial W} &= \sum_{t} \frac{\partial \mathcal{L}_t}{\partial W} \end{aligned} $$ For $V$, the gradient depends only on the values at the current time step, so no backpropagation through earlier states is needed. Writing $z_3 = Vs_3$ for the softmax input: $$ \begin{aligned} \frac{\partial \mathcal{L}_3}{\partial V} &= \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial V} = \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial z_3} \cdot \frac{\partial z_3}{\partial V} = \frac{\partial \mathcal{L}_3}{\partial z_3} \cdot \frac{\partial z_3}{\partial V} = (\hat{y_3} - y_3) \otimes s_3 \end{aligned} $$ where $\otimes$ denotes the outer product. For $W$, the situation is more complicated because $s_3$ depends on $s_2$, which in turn depends on $W$ and $s_1$, and so on: $$ \begin{aligned} \frac{\partial \mathcal{L}_3}{\partial W} &= \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial s_3} \cdot \frac{\partial s_3}{\partial W} \\ &= \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial s_3} \cdot \sum_{k=0}^{3} \frac{\partial s_3}{\partial s_k} \cdot \frac{\partial s_k}{\partial W} \\ &= \sum_{k=0}^{3} \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial s_3} \cdot \frac{\partial s_3}{\partial s_k} \cdot \frac{\partial s_k}{\partial W} \\ &= \sum_{k=0}^{3} \frac{\partial \mathcal{L}_3}{\partial \hat{y_3}} \cdot \frac{\partial \hat{y_3}}{\partial s_3} \cdot \left(\prod_{j=k+1}^{3} \frac{\partial s_j}{\partial s_{j-1}}\right) \cdot \frac{\partial s_k}{\partial W} \\ \end{aligned} $$ where $\frac{\partial s_3}{\partial s_k}$ is itself expanded by the chain rule into the product of adjacent-step Jacobians $\frac{\partial s_j}{\partial s_{j-1}}$.
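
As a hedged sketch of these sums in code, the snippet below computes the gradients of the loss at a single time step $t$ (like $\mathcal{L}_3$) with respect to $V$ and $W$, reusing `states` and `outputs` from the forward-pass sketch above. The variable names and the $s_{-1} = 0$ convention are assumptions for illustration.

```python
import numpy as np

def bptt_step_grads(t, ys, states, outputs, W, V):
    dz = outputs[t] - ys[t]                      # dL_t/dz_t = y_hat_t - y_t
    dV = np.outer(dz, states[t])                 # (y_hat_t - y_t) outer s_t
    ds = V.T @ dz                                # dL_t/ds_t
    dW = np.zeros_like(W)
    # Walk k = t, t-1, ..., 0; each pass multiplies in one Jacobian ds_j/ds_{j-1}.
    for k in range(t, -1, -1):
        da = ds * (1.0 - states[k] ** 2)         # back through tanh at step k
        s_prev = states[k - 1] if k > 0 else np.zeros_like(states[0])
        dW += np.outer(da, s_prev)               # contribution of ds_k/dW
        ds = W.T @ da                            # via Jacobian ds_k/ds_{k-1} = diag(1 - s_k^2) W
    return dV, dW
```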

Vanishing Gradient

Consider a “simple” RNN with a tanh non-linearity: $$ h_t = \tanh(\mathbf{W}h_{t-1} + \mathbf{A}x_t) $$ Explain how vanishing gradients can occur by considering backpropagation through a hyperbolic tangent function. Would using a sigmoid non-linearity instead help?

Suppose the loss function is $\mathcal{L}(y_t, \hat{y_t})$, where $\hat{y_t}$ is the output of the network at time $t$. According to backpropagation through time (BPTT), the gradient for $W$ is:

$$ \begin{aligned} \frac{\partial \mathcal{L}(y_t, \hat{y_t})}{\partial W} &= \frac{\partial \mathcal{L}(y_t, \hat{y_t})}{\partial \hat{y_t}} \cdot \frac{\partial \hat{y_t}}{\partial h_t} \cdot \frac{\partial h_t}{\partial W} \\ &= \sum_{k=0}^{t} \frac{\partial \mathcal{L}(y_t, \hat{y_t})}{\partial \hat{y_t}} \cdot \frac{\partial \hat{y_t}}{\partial h_t} \cdot \frac{\partial h_t}{\partial h_k} \cdot \frac{\partial h_k}{\partial W} \\ &= \sum_{k=0}^{t} \frac{\partial \mathcal{L}(y_t, \hat{y_t})}{\partial \hat{y_t}} \cdot \frac{\partial \hat{y_t}}{\partial h_t} \cdot \left(\prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}}\right) \cdot \frac{\partial h_k}{\partial W} \\ \end{aligned} $$

To compute the partial gradient of $h_t$, we need to know the gradient of $\tanh$, which is:

$$ \begin{aligned} \frac{d\tanh(x)}{dx} &= \frac{d}{dx}\left[\frac{\sinh(x)}{\cosh(x)}\right] = 1 - \tanh^2(x) \in (0, 1] \end{aligned} $$
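
Filling in the skipped quotient-rule step (using $\cosh^2(x) - \sinh^2(x) = 1$):

$$ \frac{d}{dx}\left[\frac{\sinh(x)}{\cosh(x)}\right] = \frac{\cosh^2(x) - \sinh^2(x)}{\cosh^2(x)} = \frac{1}{\cosh^2(x)} = 1 - \tanh^2(x) $$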

Then:

$$ \begin{aligned} \frac{\partial h_j}{\partial h_{j-1}} &= \frac{\partial \tanh(\mathbf{W}h_{j-1} + \mathbf{A}x_j)}{\partial h_{j-1}} \\ &= \text{diag}\left[1 - \tanh^2(\mathbf{W}h_{j-1} + \mathbf{A}x_j) \right] \cdot \mathbf{W} \\ \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} &= \prod_{j=k+1}^{t} \text{diag}\left[1 - \tanh^2(\mathbf{W}h_{j-1} + \mathbf{A}x_j) \right] \cdot \mathbf{W} \\ &\approx \epsilon^{t-k} \cdot \mathbf{W}^{t-k} \quad\text{ where } \epsilon \in (0, 1] \end{aligned} $$ where the approximation treats each $1 - \tanh^2(\cdot)$ factor as a scalar $\epsilon$.

With small values in the matrix $\mathbf{W}$ and a long sequence (large $t$), the gradient $\frac{\partial \mathcal{L}(y_t, \hat{y_t})}{\partial W}$ shrinks exponentially fast and eventually vanishes due to the $\epsilon^{t-k}$ term.
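
A small numerical illustration of this shrinking product is sketched below: it multiplies the per-step Jacobians $\text{diag}(1 - \tanh^2) \cdot \mathbf{W}$ for a randomly drawn small $\mathbf{W}$ and prints the norm of the running product. The sizes and the 0.05/0.1 scales are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
H, T = 50, 60
W = rng.normal(scale=0.05, size=(H, H))          # "small values" in W
A = rng.normal(scale=0.1, size=(H, H))
h = np.zeros(H)
jacobian_product = np.eye(H)
for t in range(1, T + 1):
    x_t = rng.normal(size=H)
    h = np.tanh(W @ h + A @ x_t)
    step_jacobian = np.diag(1.0 - h ** 2) @ W    # dh_t/dh_{t-1}
    jacobian_product = step_jacobian @ jacobian_product
    if t % 10 == 0:
        print(t, np.linalg.norm(jacobian_product))
# The printed norms shrink roughly exponentially in t: the contribution of
# early time steps to the gradient vanishes.
```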

It is also possible to get exploding gradients due to the $\mathbf{W}^{t-k}$ term. This problem, however, is easy to detect during training and can be mitigated by gradient clipping, i.e. capping the gradient at a preset maximum value.
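
A minimal sketch of gradient clipping by global norm is shown below; the `max_norm` value and the list-of-arrays layout are assumptions for illustration, not part of the text.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=5.0):
    """Rescale a list of gradient arrays so their combined norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total_norm > max_norm:
        grads = [g * (max_norm / total_norm) for g in grads]
    return grads
```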

The gradient of sigmoid is:

$$ \begin{aligned} \frac{\partial \sigma(x)}{\partial x} &= \sigma(x) \cdot [1 - \sigma(x)] \in (0, 0.25] \\ \end{aligned} $$

The range of this gradient shows that the sigmoid function would also suffer from the vanishing gradient problem, just like $\tanh$; in fact its gradient is at most $0.25$, so switching to sigmoid does not help. The $\tanh$ function is usually preferred because its gradient is larger over the same range of inputs.
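
A quick numerical check of this comparison (the input range $[-4, 4]$ is an arbitrary choice):

```python
import numpy as np

x = np.linspace(-4.0, 4.0, 401)
sigmoid = 1.0 / (1.0 + np.exp(-x))
d_sigmoid = sigmoid * (1.0 - sigmoid)
d_tanh = 1.0 - np.tanh(x) ** 2
print("max sigmoid gradient:", d_sigmoid.max())   # 0.25 at x = 0
print("max tanh gradient:   ", d_tanh.max())      # 1.0  at x = 0
```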

Solutions to the vanishing gradient problem include using activation functions such as ReLU (whose gradient is either 0 or 1), or using LSTM and GRU architectures.
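
As a tiny sketch of the ReLU point: its gradient is exactly 0 or 1, so active units pass the gradient through without the repeated $(1 - \tanh^2)$ shrinkage factor.

```python
import numpy as np

def relu_grad(x):
    # Derivative of max(0, x): exactly 0 for negative inputs, 1 for positive inputs.
    return (x > 0).astype(float)

print(relu_grad(np.array([-2.0, -0.5, 0.3, 4.0])))   # [0. 0. 1. 1.]
```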
